{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to Use ENAS/ProxylessNAS in Ten Minutes\n",
    "\n",
    ":label:`sec_proxyless`\n",
    "\n",
    "\n",
    "## What is the Key Idea of ENAS and ProxylessNAS?\n",
    "\n",
    "Traditional reinforcement learning-based neural architecture search learns an architecture controller\n",
    "by iteratively sampling the architecture and training the model to get final reward to update the controller. It is extremely expensive process due to training CNN.\n",
    "\n",
    "![ProxylessNAS](https://raw.githubusercontent.com/zhanghang1989/AutoGluonWebdata/master/docs/tutorial/proxyless.png)\n",
    "\n",
    "Recent work of ENAS and ProxylessNAS construct an over-parameterized network (supernet) and share the weights across different architecture to speed up the search speed. The reward is calculated every few iterations instead of every training period.\n",
    "\n",
    "Import MXNet and AutoGluon:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autogluon as ag\n",
    "import mxnet as mx\n",
    "import mxnet.gluon.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to Construct a SuperNet\n",
    "\n",
    "Basic NN blocks for CNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Identity(mx.gluon.HybridBlock):\n",
    "    def hybrid_forward(self, F, x):\n",
    "        return x\n",
    "    \n",
    "class ConvBNReLU(mx.gluon.HybridBlock):\n",
    "    def __init__(self, in_channels, channels, kernel, stride):\n",
    "        super().__init__()\n",
    "        padding = (kernel - 1) // 2\n",
    "        self.conv = nn.Conv2D(channels, kernel, stride, padding, in_channels=in_channels)\n",
    "        self.bn = nn.BatchNorm(in_channels=channels)\n",
    "        self.relu = nn.Activation('relu')\n",
    "    def hybrid_forward(self, F, x):\n",
    "        return self.relu(self.bn(self.conv(x)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AutoGluon ENAS Unit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.contrib.enas import *\n",
    "\n",
    "@enas_unit()\n",
    "class ResUnit(mx.gluon.HybridBlock):\n",
    "    def __init__(self, in_channels, channels, hidden_channels, kernel, stride):\n",
    "        super().__init__()\n",
    "        self.conv1 = ConvBNReLU(in_channels, hidden_channels, kernel, stride)\n",
    "        self.conv2 = ConvBNReLU(hidden_channels, channels, kernel, 1)\n",
    "        if in_channels == channels and stride == 1:\n",
    "            self.shortcut = Identity()\n",
    "        else:\n",
    "            self.shortcut = nn.Conv2D(channels, 1, stride, in_channels=in_channels)\n",
    "    def hybrid_forward(self, F, x):\n",
    "        return self.conv2(self.conv1(x)) + self.shortcut(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AutoGluon Sequntial\n",
    "\n",
    "Creating a ENAS network using Sequential Block:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"136pt\" height=\"620pt\"\n",
       " viewBox=\"0.00 0.00 136.00 620.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 616)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-616 132,-616 132,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<polygon fill=\"#bebada\" stroke=\"#b2dfee\" points=\"128,-612 0,-612 0,-576 128,-576 128,-612\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-590.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(H4.K3.S2.)</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<polygon fill=\"#bebada\" stroke=\"#b2dfee\" points=\"128,-540 0,-540 0,-504 128,-504 128,-540\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-518.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K3.H8.S2.)</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-575.8314C64,-568.131 64,-558.9743 64,-550.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-550.4132 64,-540.4133 60.5001,-550.4133 67.5001,-550.4132\"/>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<polygon fill=\"#bebada\" stroke=\"#b2dfee\" points=\"128,-468 0,-468 0,-432 128,-432 128,-468\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K3.H8.S2.)</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-503.8314C64,-496.131 64,-486.9743 64,-478.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-478.4132 64,-468.4133 60.5001,-478.4133 67.5001,-478.4132\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>3</title>\n",
       "<polygon fill=\"#bebada\" stroke=\"#b2dfee\" points=\"128,-396 0,-396 0,-360 128,-360 128,-396\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K3.H8.S1.)</text>\n",
       "</g>\n",
       "<!-- 2&#45;&gt;3 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>2&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-431.8314C64,-424.131 64,-414.9743 64,-406.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-406.4132 64,-396.4133 60.5001,-406.4133 67.5001,-406.4132\"/>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>4</title>\n",
       "<polygon fill=\"#bebada\" stroke=\"#b2dfee\" points=\"128,-324 0,-324 0,-288 128,-288 128,-324\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K3.H8.S1.)</text>\n",
       "</g>\n",
       "<!-- 3&#45;&gt;4 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>3&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-359.8314C64,-352.131 64,-342.9743 64,-334.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-334.4132 64,-324.4133 60.5001,-334.4133 67.5001,-334.4132\"/>\n",
       "</g>\n",
       "<!-- 5 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>5</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"123,-252 5,-252 5,-216 123,-216 123,-252\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">GlobalAvgPool2D</text>\n",
       "</g>\n",
       "<!-- 4&#45;&gt;5 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>4&#45;&gt;5</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-287.8314C64,-280.131 64,-270.9743 64,-262.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-262.4132 64,-252.4133 60.5001,-262.4133 67.5001,-262.4132\"/>\n",
       "</g>\n",
       "<!-- 6 -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>6</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"91,-180 37,-180 37,-144 91,-144 91,-180\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Flatten</text>\n",
       "</g>\n",
       "<!-- 5&#45;&gt;6 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>5&#45;&gt;6</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-215.8314C64,-208.131 64,-198.9743 64,-190.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-190.4132 64,-180.4133 60.5001,-190.4133 67.5001,-190.4132\"/>\n",
       "</g>\n",
       "<!-- 7 -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>7</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"101,-108 27,-108 27,-72 101,-72 101,-108\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Activation</text>\n",
       "</g>\n",
       "<!-- 6&#45;&gt;7 -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>6&#45;&gt;7</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-143.8314C64,-136.131 64,-126.9743 64,-118.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-118.4132 64,-108.4133 60.5001,-118.4133 67.5001,-118.4132\"/>\n",
       "</g>\n",
       "<!-- 8 -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>8</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"91,-36 37,-36 37,0 91,0 91,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Dense</text>\n",
       "</g>\n",
       "<!-- 7&#45;&gt;8 -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>7&#45;&gt;8</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-71.8314C64,-64.131 64,-54.9743 64,-46.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-46.4132 64,-36.4133 60.5001,-46.4133 67.5001,-46.4132\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7f8805e33b50>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mynet = ENAS_Sequential(\n",
    "    ResUnit(1, 8, hidden_channels=ag.space.Categorical(4, 8), kernel=ag.space.Categorical(3, 5), stride=2),\n",
    "    ResUnit(8, 8, hidden_channels=8, kernel=ag.space.Categorical(3, 5), stride=2),\n",
    "    ResUnit(8, 16, hidden_channels=8, kernel=ag.space.Categorical(3, 5), stride=2),\n",
    "    ResUnit(16, 16, hidden_channels=8, kernel=ag.space.Categorical(3, 5), stride=1, with_zero=True),\n",
    "    ResUnit(16, 16, hidden_channels=8, kernel=ag.space.Categorical(3, 5), stride=1, with_zero=True),\n",
    "    nn.GlobalAvgPool2D(),\n",
    "    nn.Flatten(),\n",
    "    nn.Activation('relu'),\n",
    "    nn.Dense(10, in_units=16),\n",
    ")\n",
    "\n",
    "mynet.initialize()\n",
    "\n",
    "mynet.graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate Network Latency and Define Reward Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = mx.nd.random.uniform(shape=(1, 1, 28, 28))\n",
    "y = mynet.evaluate_latency(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Show the latencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average latency is 7.55 ms, latency of the current architecture is 8.82 ms\n"
     ]
    }
   ],
   "source": [
    "print('Average latency is {:.2f} ms, latency of the current architecture is {:.2f} ms'.format(mynet.avg_latency, mynet.latency))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also provide number of params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8714"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mynet.nparams"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the reward function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_fn = lambda metric, net: metric * ((net.avg_latency / net.latency) ** 0.1)"
   ]
  },
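  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of this trade-off (a sketch that uses only the latencies printed in the saved output above; the symbols are introduced here for illustration), write the reward as\n",
    "\n",
    "$$R = m \\cdot \\left(\\frac{\\bar{L}}{L}\\right)^{\\alpha},$$\n",
    "\n",
    "where $m$ is the metric, $\\bar{L}$ is `net.avg_latency`, $L$ is `net.latency`, and $\\alpha = 0.1$ is the exponent. For the architecture measured above, $\\bar{L}/L = 7.55 / 8.82 \\approx 0.856$, so the metric is multiplied by $0.856^{0.1} \\approx 0.985$, a penalty of about 1.5%. A larger exponent such as $\\alpha = 0.8$ would give $0.856^{0.8} \\approx 0.883$, so slower-than-average architectures lose far more reward."
   ]
  },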
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start the Training\n",
    "\n",
    "Construct experiment scheduler, which automatically creates an RL controller based on user-defined search space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "get_built_in_dataset mnist\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading /home/ec2-user/.mxnet/datasets/mnist/train-images-idx3-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/train-images-idx3-ubyte.gz...\n",
      "Downloading /home/ec2-user/.mxnet/datasets/mnist/train-labels-idx1-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/train-labels-idx1-ubyte.gz...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "get_built_in_dataset mnist\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading /home/ec2-user/.mxnet/datasets/mnist/t10k-images-idx3-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/t10k-images-idx3-ubyte.gz...\n",
      "Downloading /home/ec2-user/.mxnet/datasets/mnist/t10k-labels-idx1-ubyte.gz from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/mnist/t10k-labels-idx1-ubyte.gz...\n"
     ]
    }
   ],
   "source": [
    "scheduler = ENAS_Scheduler(mynet, train_set='mnist',\n",
    "                           reward_fn=reward_fn, batch_size=128, num_gpus=1,\n",
    "                           warmup_epochs=0, epochs=1, controller_lr=3e-3,\n",
    "                           plot_frequency=10, update_arch_frequency=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start the training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e37e35e8baa0493d83e4a22b3d78adcd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value=''))), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce3f78738c2e45fcad7c803ce3251cde",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=468), HTML(value=''))), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e2c6df493490440c82cb3dfb7eadbed3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=79), HTML(value=''))), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "scheduler.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resulting architecture is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"136pt\" height=\"476pt\"\n",
       " viewBox=\"0.00 0.00 136.00 476.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 472)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-472 132,-472 132,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<polygon fill=\"#fdb462\" stroke=\"#b2dfee\" points=\"128,-468 0,-468 0,-432 128,-432 128,-468\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(H8.K5.S2.)</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<polygon fill=\"#fdb462\" stroke=\"#b2dfee\" points=\"128,-396 0,-396 0,-360 128,-360 128,-396\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K5.H8.S2.)</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-431.8314C64,-424.131 64,-414.9743 64,-406.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-406.4132 64,-396.4133 60.5001,-406.4133 67.5001,-406.4132\"/>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<polygon fill=\"#bebada\" stroke=\"#b2dfee\" points=\"128,-324 0,-324 0,-288 128,-288 128,-324\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K3.H8.S2.)</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-359.8314C64,-352.131 64,-342.9743 64,-334.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-334.4132 64,-324.4133 60.5001,-334.4133 67.5001,-334.4132\"/>\n",
       "</g>\n",
       "<!-- 5 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>5</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"123,-252 5,-252 5,-216 123,-216 123,-252\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">GlobalAvgPool2D</text>\n",
       "</g>\n",
       "<!-- 2&#45;&gt;5 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>2&#45;&gt;5</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-287.8314C64,-280.131 64,-270.9743 64,-262.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-262.4132 64,-252.4133 60.5001,-262.4133 67.5001,-262.4132\"/>\n",
       "</g>\n",
       "<!-- 6 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>6</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"91,-180 37,-180 37,-144 91,-144 91,-180\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Flatten</text>\n",
       "</g>\n",
       "<!-- 5&#45;&gt;6 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>5&#45;&gt;6</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-215.8314C64,-208.131 64,-198.9743 64,-190.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-190.4132 64,-180.4133 60.5001,-190.4133 67.5001,-190.4132\"/>\n",
       "</g>\n",
       "<!-- 7 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>7</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"101,-108 27,-108 27,-72 101,-72 101,-108\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Activation</text>\n",
       "</g>\n",
       "<!-- 6&#45;&gt;7 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>6&#45;&gt;7</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-143.8314C64,-136.131 64,-126.9743 64,-118.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-118.4132 64,-108.4133 60.5001,-118.4133 67.5001,-118.4132\"/>\n",
       "</g>\n",
       "<!-- 8 -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>8</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"91,-36 37,-36 37,0 91,0 91,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Dense</text>\n",
       "</g>\n",
       "<!-- 7&#45;&gt;8 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>7&#45;&gt;8</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-71.8314C64,-64.131 64,-54.9743 64,-46.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-46.4132 64,-36.4133 60.5001,-46.4133 67.5001,-46.4132\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7f8804981c90>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mynet.graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Change the reward trade-off:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "get_built_in_dataset mnist\n",
      "get_built_in_dataset mnist\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0eac4dca035740c59acaee69e81aa88a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value=''))), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "785ca3c252134dc18eab4c3f1c596673",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=468), HTML(value=''))), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "35cacaac34274420af4910d4c4e55e26",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, max=79), HTML(value=''))), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "reward_fn = lambda metric, net: metric * ((net.avg_latency / net.latency) ** 0.8)\n",
    "mynet.initialize(force_reinit=True)\n",
    "scheduler = ENAS_Scheduler(mynet, train_set='mnist',\n",
    "                           reward_fn=reward_fn, batch_size=128, num_gpus=1,\n",
    "                           warmup_epochs=0, epochs=1, controller_lr=3e-3,\n",
    "                           plot_frequency=10, update_arch_frequency=5)\n",
    "scheduler.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resulting architecture is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"136pt\" height=\"476pt\"\n",
       " viewBox=\"0.00 0.00 136.00 476.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 472)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-472 132,-472 132,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<polygon fill=\"#fdb462\" stroke=\"#b2dfee\" points=\"128,-468 0,-468 0,-432 128,-432 128,-468\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-446.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(H8.K5.S2.)</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<polygon fill=\"#fdb462\" stroke=\"#b2dfee\" points=\"128,-396 0,-396 0,-360 128,-360 128,-396\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-374.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K5.H8.S2.)</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-431.8314C64,-424.131 64,-414.9743 64,-406.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-406.4132 64,-396.4133 60.5001,-406.4133 67.5001,-406.4132\"/>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<polygon fill=\"#fdb462\" stroke=\"#b2dfee\" points=\"128,-324 0,-324 0,-288 128,-288 128,-324\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-302.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">ResUnit(K5.H8.S2.)</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-359.8314C64,-352.131 64,-342.9743 64,-334.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-334.4132 64,-324.4133 60.5001,-334.4133 67.5001,-334.4132\"/>\n",
       "</g>\n",
       "<!-- 5 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>5</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"123,-252 5,-252 5,-216 123,-216 123,-252\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">GlobalAvgPool2D</text>\n",
       "</g>\n",
       "<!-- 2&#45;&gt;5 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>2&#45;&gt;5</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-287.8314C64,-280.131 64,-270.9743 64,-262.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-262.4132 64,-252.4133 60.5001,-262.4133 67.5001,-262.4132\"/>\n",
       "</g>\n",
       "<!-- 6 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>6</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"91,-180 37,-180 37,-144 91,-144 91,-180\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Flatten</text>\n",
       "</g>\n",
       "<!-- 5&#45;&gt;6 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>5&#45;&gt;6</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-215.8314C64,-208.131 64,-198.9743 64,-190.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-190.4132 64,-180.4133 60.5001,-190.4133 67.5001,-190.4132\"/>\n",
       "</g>\n",
       "<!-- 7 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>7</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"101,-108 27,-108 27,-72 101,-72 101,-108\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Activation</text>\n",
       "</g>\n",
       "<!-- 6&#45;&gt;7 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>6&#45;&gt;7</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-143.8314C64,-136.131 64,-126.9743 64,-118.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-118.4132 64,-108.4133 60.5001,-118.4133 67.5001,-118.4132\"/>\n",
       "</g>\n",
       "<!-- 8 -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>8</title>\n",
       "<polygon fill=\"#b2dfee\" stroke=\"#b2dfee\" points=\"91,-36 37,-36 37,0 91,0 91,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"64\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Dense</text>\n",
       "</g>\n",
       "<!-- 7&#45;&gt;8 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>7&#45;&gt;8</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M64,-71.8314C64,-64.131 64,-54.9743 64,-46.4166\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"67.5001,-46.4132 64,-36.4133 60.5001,-46.4133 67.5001,-46.4132\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7f8805af4f50>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mynet.graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reference\n",
    "\n",
    "[1] Efficient Neural Architecture Search via Parameter Sharing\n",
    "    H Pham, MY Guan, B Zoph, QV Le, J Dean\n",
    "    *International Conference on Machine Learning (ICML)*\n",
    "\n",
    "[3] ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware\n",
    "    Han Cai, Ligeng Zhu, Song Han\n",
    "    *International Conference on Learning Representations (ICLR)*, 2019."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_autogluon",
   "language": "python",
   "name": "conda_autogluon"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
